import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from image_synthesis.modeling.utils.misc import logits_top_k

class BaseTransformer(nn.Module):
    """Base class that adds token-sampling loops on top of a transformer.

    The sampling methods call ``self.forward(input=..., return_loss=False,
    return_logits=True, ...)`` and read ``'logits'`` (and, when attention
    weights are requested, ``'attention_weight'``) from the returned dict.
    ``forward`` itself is not defined here and has to be supplied by the
    subclass.
    """

    # Chooses between `sample_causal` (True) and `sample_uncausal` (False)
    # when `sample` is called.
    causal = True
    # Number of condition tokens that precede the content tokens.
    condition_seq_len = 0
    # Number of content tokens the sampling loops iterate over / produce.
    content_seq_len = 0

    @property
    def device(self):
        """Not implemented in the base class; always raises here."""
        raise NotImplementedError

    @torch.no_grad()
    def _get_sample_from_logits(self, logits, filter_ratio=0.5,
                                temperature=1.0):
        """Draw one token index per batch element from ``logits``.

        Args:
            logits: tensor documented by the callers as B x 1 x C.
            filter_ratio: forwarded to ``logits_top_k`` (with ``minimum=1``).
            temperature: factor the filtered logits are multiplied by before
                the softmax.

        Returns:
            Tensor of shape B x 1 holding the sampled indices.
        """
        batch_size = logits.shape[0]
        filtered_logits = logits_top_k(logits, filter_ratio=filter_ratio, minimum=1)  # B x 1 x C
        # NOTE(review): the logits are *multiplied* by `temperature`, so values
        # above 1 sharpen the distribution. This is the inverse of the common
        # "divide by temperature" convention -- confirm it is intended before
        # changing it, since callers may depend on it.
        probs = F.softmax(filtered_logits * temperature, dim=-1)  # B x 1 x C
        # multinomial expects a 2-D input, hence the flatten to (B, C).
        flat_probs = probs.view(-1, probs.shape[-1])
        return torch.multinomial(flat_probs, 1).view(batch_size, 1)  # B x 1

    @torch.no_grad()
    def sample_causal(
        self,
        condition_token,
        condition_mask,
        content_token=None,
        filter_ratio = 0.5,
        temperature = 1.0,
        return_att_weight=False,
        return_logits=False,
        **kwargs,
    ):
        """Sample content tokens one position at a time, left to right.

        Starting from the given ``content_token`` prefix (empty when None),
        the model is run once per missing position; the logits of the last
        position are sampled and the new token is appended to the prefix.

        Args:
            condition_token: B x condition_len tensor of condition tokens.
            condition_mask: passed through to ``forward`` unchanged.
            content_token: optional B x k tensor of already known content.
            filter_ratio: forwarded to ``_get_sample_from_logits``.
            temperature: forwarded to ``_get_sample_from_logits``.
            return_att_weight: also return the per-step attention rows.
            return_logits: also return the per-step logits.
            **kwargs: accepted and ignored.

        Returns:
            dict with ``'content_token'`` and, on request,
            ``'condition_attention'`` / ``'content_attention'`` and
            ``'logits'`` (B x steps x C).

        Raises:
            AssertionError: if tokens remain to be sampled but fewer than
                ``condition_seq_len`` tokens were supplied in total.
        """
        batch_size = condition_token.shape[0]

        if content_token is None:
            # Empty content prefix with the dtype/device of the condition.
            content_token = torch.zeros((batch_size, 0)).to(condition_token)

        cond_len = self.condition_seq_len
        total_len = self.content_seq_len + cond_len

        if return_att_weight:
            # Row i holds the attention of the i-th sampled content position.
            att_weight = {
                'condition_attention': torch.zeros(batch_size, self.content_seq_len, cond_len),
                'content_attention': torch.zeros(batch_size, self.content_seq_len, self.content_seq_len)
            }

        current_length = condition_token.shape[1] + content_token.shape[1]
        if current_length < total_len:
            # The length only grows inside the loop, so checking the start
            # value once is equivalent to checking on every step.
            assert current_length >= cond_len, 'Please give the complete condition!'

        # Per-step logits are gathered here and concatenated once at the end,
        # which avoids re-copying the growing buffer on every step and keeps
        # the dtype of the logits themselves.
        step_logits = []

        for cur_len in range(current_length, total_len):
            trans_out = self.forward(input={
                                    'condition_token': condition_token,
                                    'content_token': content_token,
                                    'condition_mask': condition_mask,
                                    },
                                return_loss=False,
                                return_logits=True,
                                return_att_weight=return_att_weight)
            logits = trans_out['logits'][:, -1:, :]  # B x 1 x C

            if return_att_weight:
                cur_cont_len = cur_len - cond_len
                att_cond_weight = trans_out['attention_weight'][:, -1:, :cond_len]  # B x 1 x cond_len
                att_weight['condition_attention'][:, cur_cont_len] = att_cond_weight[:, 0, :]
                if cur_cont_len > 0:
                    # A slice of `-0:` would select everything, hence the guard.
                    att_cont_weight = trans_out['attention_weight'][:, -1:, -cur_cont_len:]  # B x 1 x cur_cont_len
                    att_weight['content_attention'][:, cur_cont_len, :att_cont_weight.shape[-1]] = att_cont_weight[:, 0, :]

            sample = self._get_sample_from_logits(logits, filter_ratio=filter_ratio, temperature=temperature)  # B x 1
            content_token = torch.cat((content_token, sample), dim=1)
            if return_logits:
                step_logits.append(logits)

        output = {
            'content_token': content_token,
        }
        if return_att_weight:
            output.update(att_weight)
        if return_logits:
            if step_logits:
                output['logits'] = torch.cat(step_logits, dim=1)  # B x steps x C
            else:
                # Nothing was sampled: return an empty floating-point tensor
                # with the class dimension taken from the output layer.
                num_cls = self.to_logits[-1].out_features
                output['logits'] = torch.zeros(batch_size, 0, num_cls, device=condition_token.device)

        return output

    @torch.no_grad()
    def sample_uncausal(
        self,
        *,
        condition_token=None,
        condition_mask=None,
        content_token,
        content_mask, # True for unmasked tokens, False for masked tokens
        filter_ratio = 0.5,
        temperature = 1.0,
        **kwargs,
    ):
        """Fill the masked content positions one by one, left to right.

        For every position that is masked in at least one batch element the
        model is run on the current tokens/mask, a token is sampled for that
        position, and it is written only into the elements where the
        position was masked. The inputs are cloned, not modified.

        Args:
            condition_token: passed through to ``forward`` unchanged.
            condition_mask: passed through to ``forward`` unchanged.
            content_token: B x content_len tensor of content tokens.
            content_mask: B x content_len boolean tensor, True for known
                (unmasked) tokens and False for tokens to be sampled.
            filter_ratio: forwarded to ``_get_sample_from_logits``.
            temperature: forwarded to ``_get_sample_from_logits``.
            **kwargs: accepted and ignored.

        Returns:
            dict with the completed ``'content_token'`` tensor.
        """
        batch_size = content_token.shape[0]  # content_token is mandatory here
        content_token = content_token.clone()
        content_mask = content_mask.clone()

        for pos in range(self.content_seq_len):
            nxt = pos + 1
            known = content_mask[:, pos:nxt]  # B x 1

            if known.sum() < batch_size:  # at least one element is masked here
                logits = self.forward(
                    input={
                        'condition_token': condition_token,
                        'content_token': content_token,
                        'condition_mask': condition_mask,
                        'content_mask': content_mask,
                    },
                    return_loss=False,
                    return_logits=True
                )['logits'][:, pos:nxt]  # B x 1 x C

                sample = self._get_sample_from_logits(logits, filter_ratio=filter_ratio, temperature=temperature)  # B x 1

                to_fill = ~known  # elements whose token at `pos` is masked

                # The slices are views, so these writes update the clones.
                content_token[:, pos:nxt][to_fill] = sample[to_fill]
                content_mask[:, pos:nxt][to_fill] = True
                # Everything up to and including `pos` must now be unmasked.
                assert content_mask[:, :nxt].sum() == batch_size * nxt

        output = {
            'content_token': content_token
        }
        return output

    @torch.no_grad()
    def sample(self, **kwargs):
        """Dispatch to ``sample_causal`` or ``sample_uncausal`` via ``self.causal``."""
        if self.causal:
            return self.sample_causal(**kwargs)
        return self.sample_uncausal(**kwargs)
        